package org.apache.gluten.execution

import io.substrait.proto.CrossRel
import org.apache.gluten.extension.ValidationResult
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
 * Omni-backend variant of [[BroadcastNestedLoopJoinExecTransformer]].
 *
 * Overrides three things from the base class:
 *  - the Substrait `CrossRel` join-type mapping;
 *  - the join-type / build-side validation;
 *  - how the broadcast build side is supplied as a columnar input RDD.
 *
 * @param left      left child plan
 * @param right     right child plan
 * @param buildSide which side of the join is the build (broadcast) side
 * @param joinType  Spark join type
 * @param condition optional join condition
 */
case class OmniBroadcastNestedLoopJoinExecTransformer(
    left: SparkPlan,
    right: SparkPlan,
    buildSide: BuildSide,
    joinType: JoinType,
    condition: Option[Expression])
  extends BroadcastNestedLoopJoinExecTransformer(
    left,
    right,
    buildSide,
    joinType,
    condition) {

  /**
   * Substrait `CrossRel` join type corresponding to this join's Spark [[JoinType]].
   *
   * Join types without an explicit case map to `CrossRel.JoinType.UNRECOGNIZED`.
   *
   * NOTE(review): only `Inner` is matched. Other `InnerLike` types such as `Cross` fall through
   * to `UNRECOGNIZED`, and `validateJoinTypeAndBuildSide` rejects them too. Confirm this is
   * intended for the Omni backend.
   * NOTE(review): `FullOuter` has a mapping here although `validateJoinTypeAndBuildSide`
   * rejects it.
   */
  override protected lazy val substraitJoinType: CrossRel.JoinType = joinType match {
    case Inner => CrossRel.JoinType.JOIN_TYPE_INNER
    case FullOuter => CrossRel.JoinType.JOIN_TYPE_OUTER
    case LeftOuter => CrossRel.JoinType.JOIN_TYPE_LEFT
    case RightOuter => CrossRel.JoinType.JOIN_TYPE_RIGHT
    case _ => CrossRel.JoinType.UNRECOGNIZED
  }

  /**
   * Checks whether this join type / build side combination is supported.
   *
   * Supported join types are `Inner`, `LeftOuter` and `RightOuter`. An outer join is rejected
   * when its build side is also its row-preserving side. That means `LeftOuter` with
   * `BuildLeft`, or `RightOuter` with `BuildRight`.
   *
   * @return `ValidationResult.succeeded` if supported. Otherwise a failed result whose message
   *         names the unsupported join type or join-type / build-side combination.
   */
  override def validateJoinTypeAndBuildSide(): ValidationResult = {
    (joinType, buildSide) match {
      // Must precede the generic "supported join type" case. These join types are supported
      // in general but not when the preserved side is broadcast. Presumably the native operator
      // requires the outer side to be streamed; confirm.
      case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
        ValidationResult.failed(s"$joinType join is not supported with $buildSide")
      case (Inner | LeftOuter | RightOuter, _) =>
        ValidationResult.succeeded
      case _ =>
        ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin")
    }
  }

  /**
   * Columnar input RDDs for this operator. These are the streamed plan's columnar input RDDs,
   * followed by an [[OmniBroadcastBuildSideRDD]] that wraps the broadcast build-side relation.
   */
  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
    val streamedRDD = getColumnarInputRDDs(streamedPlan)
    val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
    val broadcastRDD = OmniBroadcastBuildSideRDD(sparkContext, broadcast)
    // The build-side RDD is appended last, after every streamed-side input.
    streamedRDD :+ broadcastRDD
  }

  /** Returns a copy of this operator with its children replaced by `newLeft` / `newRight`. */
  override protected def withNewChildrenInternal(
      newLeft: SparkPlan,
      newRight: SparkPlan): OmniBroadcastNestedLoopJoinExecTransformer =
    copy(left = newLeft, right = newRight)
}